{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "klGNgWREsvQv"
      },
      "source": [
        "##### Copyright 2018 The TF-Agents Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nQnmcm0oI1Q-"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HNtBC6Bbb1YU"
      },
      "source": [
        "# REINFORCE エージェント\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/agents/tutorials/6_reinforce_tutorial\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/agents/tutorials/6_reinforce_tutorial.ipynb\">     <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colabで実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/agents/tutorials/6_reinforce_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/agents/tutorials/6_reinforce_tutorial.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード/a0}</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZOUOQOrFs3zn"
      },
      "source": [
        "## はじめに"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cKOCZlhUgXVK"
      },
      "source": [
        "この例は、<a>DQN チュートリアル</a>のように、Cartpole 環境で TF-Agents ライブラリを使用して [REINFORCE（強化）](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf)エージェントをトレーニングする方法を示します。\n",
        "\n",
        "![Cartpole environment](https://github.com/tensorflow/docs-l10n/blob/master/site/en-snapshot/agents/tutorials/images/cartpole.png?raw=true)\n",
        "\n",
        "ここでは、トレーニング、評価、データ収集に使用する強化学習（RL）パイプラインの全コンポーネントについて説明します。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1u9QVVsShC9X"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I5PNmEzIb9t4"
      },
      "source": [
        "以下の依存関係をインストールしていない場合は、実行します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KEHR2Ui-lo8O"
      },
      "outputs": [],
      "source": [
        "!sudo apt-get install -y xvfb ffmpeg\n",
        "!pip install 'gym==0.10.11'\n",
        "!pip install 'imageio==2.4.0'\n",
        "!pip install PILLOW\n",
        "!pip install 'pyglet==1.3.2'\n",
        "!pip install pyvirtualdisplay\n",
        "!pip install tf-agents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sMitx5qSgJk1"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import base64\n",
        "import imageio\n",
        "import IPython\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import PIL.Image\n",
        "import pyvirtualdisplay\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from tf_agents.agents.reinforce import reinforce_agent\n",
        "from tf_agents.drivers import dynamic_step_driver\n",
        "from tf_agents.environments import suite_gym\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.eval import metric_utils\n",
        "from tf_agents.metrics import tf_metrics\n",
        "from tf_agents.networks import actor_distribution_network\n",
        "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n",
        "from tf_agents.trajectories import trajectory\n",
        "from tf_agents.utils import common\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()\n",
        "\n",
        "\n",
        "# Set up a virtual display for rendering OpenAI gym environments.\n",
        "display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LmC0NDhdLIKY"
      },
      "source": [
        "## ハイパーパラメータ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HC1kNrOsLSIZ"
      },
      "outputs": [],
      "source": [
        "env_name = \"CartPole-v0\" # @param {type:\"string\"}\n",
        "num_iterations = 250 # @param {type:\"integer\"}\n",
        "collect_episodes_per_iteration = 2 # @param {type:\"integer\"}\n",
        "replay_buffer_capacity = 2000 # @param {type:\"integer\"}\n",
        "\n",
        "fc_layer_params = (100,)\n",
        "\n",
        "learning_rate = 1e-3 # @param {type:\"number\"}\n",
        "log_interval = 25 # @param {type:\"integer\"}\n",
        "num_eval_episodes = 10 # @param {type:\"integer\"}\n",
        "eval_interval = 50 # @param {type:\"integer\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VMsJC3DEgI0x"
      },
      "source": [
        "## 環境\n",
        "\n",
        "強化学習の環境は、解決しようとしているタスクまたは問題を表しています。標準環境は、`suites`を使用して TF-Agent で簡単に作成できます。OpenAI Gym、Atari、DM Control などのソースから環境を読み込むには異なる`suites`を使用し、これには文字列の環境名が与えられます。\n",
        "\n",
        "では、OpenAI Gym スイートから CartPole 環境を読み込みましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pYEz-S9gEv2-"
      },
      "outputs": [],
      "source": [
        "env = suite_gym.load(env_name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IIHYVBkuvPNw"
      },
      "source": [
        "この環境をレンダリングして、どのように見えるかを確認できます。台車の上に回転軸を固定した棒を立て、その棒が倒れないように台車を左右に動かすことが目的です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RlO7WIQHu_7D"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "env.reset()\n",
        "PIL.Image.fromarray(env.render())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B9_lskPOey18"
      },
      "source": [
        "`time_step = environment.step(action)`文は環境で`action`を取ります。返される`TimeStep`タプルには、環境の次の観測とその行動の報酬を含みます。環境の`time_step_spec()`メソッドと`action_spec()`メソッドは、 `time_step`と`action`の仕様（タイプ、形状、境界）をそれぞれ返します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "exDv57iHfwQV"
      },
      "outputs": [],
      "source": [
        "print('Observation Spec:')\n",
        "print(env.time_step_spec().observation)\n",
        "print('Action Spec:')\n",
        "print(env.action_spec())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eJCgJnx3g0yY"
      },
      "source": [
        "つまり、観測は、台車の位置と速度、軸の角度位置と速度から成る 4 つの浮動小数点数の配列であることが分かります。2 つの行動（左に動かす、または右に動かす）のみが可能なので、`action_spec`はスカラーで、0 は「左に動かす」を意味し、1 は「右に移動する」を意味します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V2UGR5t_iZX-"
      },
      "outputs": [],
      "source": [
        "time_step = env.reset()\n",
        "print('Time step:')\n",
        "print(time_step)\n",
        "\n",
        "action = np.array(1, dtype=np.int32)\n",
        "\n",
        "next_time_step = env.step(action)\n",
        "print('Next time step:')\n",
        "print(next_time_step)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuUqXAVmecTU"
      },
      "source": [
        "通常は、トレーニング用と評価用の 2 つの環境を作成します。ほとんどの環境は純粋な python で記述されていますが、`TFPyEnvironment` ラッパーを使用すると容易に TensorFlow に変換できます。元の環境の API は numpy 配列を使用していますが、`TFPyEnvironment` はこれらを `Tensors` に変換したりその逆を可能にするので、より簡単に TensorFlow ポリシーやエージェントと相互作用できます。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xp-Y4mD6eDhF"
      },
      "outputs": [],
      "source": [
        "train_py_env = suite_gym.load(env_name)\n",
        "eval_py_env = suite_gym.load(env_name)\n",
        "\n",
        "train_env = tf_py_environment.TFPyEnvironment(train_py_env)\n",
        "eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E9lW_OZYFR8A"
      },
      "source": [
        "## 行動\n",
        "\n",
        "強化学習問題の解決に使用するアルゴリズムは、`Agent`として表現されます。TF-Agents は、[DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)、[DDPG](https://arxiv.org/pdf/1509.02971.pdf)、[TD3](https://arxiv.org/pdf/1802.09477.pdf)、[PPO](https://arxiv.org/abs/1707.06347)、[SAC](https://arxiv.org/abs/1801.01290) などのさまざまな`Agents`の標準実装を提供します。\n",
        "\n",
        "REINFORCE エージェントを作成するためには、まずは任意の環境の観測から行動の予測ができるように学習可能な`Actor Network`が必要です。\n",
        "\n",
        "`Actor Network`は、観測と行動の仕様を使用して簡単に作成できます。ネットワーク内のレイヤーの指定が可能で、この例では`fc_layer_params`引数が各隠れレイヤーのサイズを表現する`ints`のタプルを設定しています（上記ハイパーパラメータのセクションをご覧ください）。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TgkdEPg_muzV"
      },
      "outputs": [],
      "source": [
        "actor_net = actor_distribution_network.ActorDistributionNetwork(\n",
        "    train_env.observation_spec(),\n",
        "    train_env.action_spec(),\n",
        "    fc_layer_params=fc_layer_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z62u55hSmviJ"
      },
      "source": [
        "また、先ほど作成したネットワークをトレーニングする`optimizer` と、ネットワークが更新された回数を追跡する`train_step_counter`変数も必要です。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jbY4yrjTEyc9"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)\n",
        "\n",
        "train_step_counter = tf.compat.v2.Variable(0)\n",
        "\n",
        "tf_agent = reinforce_agent.ReinforceAgent(\n",
        "    train_env.time_step_spec(),\n",
        "    train_env.action_spec(),\n",
        "    actor_network=actor_net,\n",
        "    optimizer=optimizer,\n",
        "    normalize_returns=True,\n",
        "    train_step_counter=train_step_counter)\n",
        "tf_agent.initialize()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I0KLrEPwkn5x"
      },
      "source": [
        "## ポリシー\n",
        "\n",
        "TF-Agent では、ポリシーとは強化学習のポリシーの標準的な概念を表し、任意の`time_step`は行動または行動の分布を生成します。主なメソッドは`policy_step = policy.step(time_step)`で、`policy_step`は名前付きのタプル`PolicyStep(action, state, info)`です。`policy_step.action`は環境に適用される`action`で、`state`はステートフル（RNN）ポリシーの状態を表し、`info`には行動のログ確率などの補助情報が含まれる場合があります。\n",
        "\n",
        "エージェントには、2 つのポリシーが含まれています。これは評価/デプロイに使用するメインのポリシー（agent.policy）と、データ収集に使用するもう 1 つのポリシー（agent.collect_policy）です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BwY7StuMkuV4"
      },
      "outputs": [],
      "source": [
        "eval_policy = tf_agent.policy\n",
        "collect_policy = tf_agent.collect_policy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "94rCXQtbUbXv"
      },
      "source": [
        "## メトリクスと評価\n",
        "\n",
        "ポリシーの評価に使用される最も一般的なメトリックは、平均リターンです。リターンは、エピソードの環境でポリシーを実行中に取得した報酬の合計であり、通常は複数のエピソード間で平均化します。平均リターンのメトリックは以下のように計算できます。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bitzHo5_UbXy"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "def compute_avg_return(environment, policy, num_episodes=10):\n",
        "\n",
        "  total_return = 0.0\n",
        "  for _ in range(num_episodes):\n",
        "\n",
        "    time_step = environment.reset()\n",
        "    episode_return = 0.0\n",
        "\n",
        "    while not time_step.is_last():\n",
        "      action_step = policy.action(time_step)\n",
        "      time_step = environment.step(action_step.action)\n",
        "      episode_return += time_step.reward\n",
        "    total_return += episode_return\n",
        "\n",
        "  avg_return = total_return / num_episodes\n",
        "  return avg_return.numpy()[0]\n",
        "\n",
        "\n",
        "# Please also see the metrics module for standard implementations of different\n",
        "# metrics."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NLva6g2jdWgr"
      },
      "source": [
        "## 再生バッファ\n",
        "\n",
        "環境から収集したデータを追跡するには、TFUniformReplayBuffer を使用します。この再生バッファは格納するテンソルを記述する仕様を使って構築され、`tf_agent.collect_data_spec`を使用してエージェントから取得することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vX2zGUWJGWAl"
      },
      "outputs": [],
      "source": [
        "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
        "    data_spec=tf_agent.collect_data_spec,\n",
        "    batch_size=train_env.batch_size,\n",
        "    max_length=replay_buffer_capacity)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZGNTDJpZs4NN"
      },
      "source": [
        "ほとんどのエージェントの場合、`collect_data_spec`は`Trajectory`という名前付きタプルであり、観測、行動、報酬、その他が含まれています。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rVD5nQ9ZGo8_"
      },
      "source": [
        "## データ収集\n",
        "\n",
        "REINFORCE はエピソード全体から学習するため、任意のデータ収集ポリシーを使用してエピソードを収集する関数を定義し、データ（観測、行動、報酬など）を再生バッファに軌跡として保存します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wr1KSAEGG4h9"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "\n",
        "def collect_episode(environment, policy, num_episodes):\n",
        "\n",
        "  episode_counter = 0\n",
        "  environment.reset()\n",
        "\n",
        "  while episode_counter < num_episodes:\n",
        "    time_step = environment.current_time_step()\n",
        "    action_step = policy.action(time_step)\n",
        "    next_time_step = environment.step(action_step.action)\n",
        "    traj = trajectory.from_transition(time_step, action_step, next_time_step)\n",
        "\n",
        "    # Add trajectory to the replay buffer\n",
        "    replay_buffer.add_batch(traj)\n",
        "\n",
        "    if traj.is_boundary():\n",
        "      episode_counter += 1\n",
        "\n",
        "\n",
        "# This loop is so common in RL, that we provide standard implementations of\n",
        "# these. For more details see the drivers module."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBc9lj9VWWtZ"
      },
      "source": [
        "## エージェントのトレーニング\n",
        "\n",
        "トレーニングループには、環境からのデータ収集とエージェントのネットワークの最適化の両方を含みます。途中でエージェントのポリシーを時々評価して、状況を確認します。\n",
        "\n",
        "以下の実行には 3 分ほどかかります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0pTbJ3PeyF-u"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "try:\n",
        "  %%time\n",
        "except:\n",
        "  pass\n",
        "\n",
        "# (Optional) Optimize by wrapping some of the code in a graph using TF function.\n",
        "tf_agent.train = common.function(tf_agent.train)\n",
        "\n",
        "# Reset the train step\n",
        "tf_agent.train_step_counter.assign(0)\n",
        "\n",
        "# Evaluate the agent's policy once before training.\n",
        "avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)\n",
        "returns = [avg_return]\n",
        "\n",
        "for _ in range(num_iterations):\n",
        "\n",
        "  # Collect a few episodes using collect_policy and save to the replay buffer.\n",
        "  collect_episode(\n",
        "      train_env, tf_agent.collect_policy, collect_episodes_per_iteration)\n",
        "\n",
        "  # Use data from the buffer and update the agent's network.\n",
        "  experience = replay_buffer.gather_all()\n",
        "  train_loss = tf_agent.train(experience)\n",
        "  replay_buffer.clear()\n",
        "\n",
        "  step = tf_agent.train_step_counter.numpy()\n",
        "\n",
        "  if step % log_interval == 0:\n",
        "    print('step = {0}: loss = {1}'.format(step, train_loss.loss))\n",
        "\n",
        "  if step % eval_interval == 0:\n",
        "    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)\n",
        "    print('step = {0}: Average Return = {1}'.format(step, avg_return))\n",
        "    returns.append(avg_return)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "68jNcA_TiJDq"
      },
      "source": [
        "## 視覚化\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aO-LWCdbbOIC"
      },
      "source": [
        "### プロット\n",
        "\n",
        "リターンとグローバルステップをプロットして、エージェントのパフォーマンスを確認できます。`Cartpole-v0`では、棒が立ったままのタイムステップごとに環境は +1 の報酬を提供します。最大ステップ数は 200 であるため、可能な最大リターン値も 200 です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NxtL1mbOYCVO"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "\n",
        "steps = range(0, num_iterations + 1, eval_interval)\n",
        "plt.plot(steps, returns)\n",
        "plt.ylabel('Average Return')\n",
        "plt.xlabel('Step')\n",
        "plt.ylim(top=250)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M7-XpPP99Cy7"
      },
      "source": [
        "### 動画"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9pGfGxSH32gn"
      },
      "source": [
        "各ステップで環境をレンダリングすると、エージェントのパフォーマンスを可視化できます。その前に、この Colab に動画を埋め込む関数を作成しましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ULaGr8pvOKbl"
      },
      "outputs": [],
      "source": [
        "def embed_mp4(filename):\n",
        "  \"\"\"Embeds an mp4 file in the notebook.\"\"\"\n",
        "  video = open(filename,'rb').read()\n",
        "  b64 = base64.b64encode(video)\n",
        "  tag = '''\n",
        "  <video width=\"640\" height=\"480\" controls>\n",
        "    <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\">\n",
        "  Your browser does not support the video tag.\n",
        "  </video>'''.format(b64.decode())\n",
        "\n",
        "  return IPython.display.HTML(tag)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9c_PH-pX4Pr5"
      },
      "source": [
        "次のコードは、いくつかのエピソードに渡るエージェントのポリシーを可視化します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "owOVWB158NlF"
      },
      "outputs": [],
      "source": [
        "num_episodes = 3\n",
        "video_filename = 'imageio.mp4'\n",
        "with imageio.get_writer(video_filename, fps=60) as video:\n",
        "  for _ in range(num_episodes):\n",
        "    time_step = eval_env.reset()\n",
        "    video.append_data(eval_py_env.render())\n",
        "    while not time_step.is_last():\n",
        "      action_step = tf_agent.policy.action(time_step)\n",
        "      time_step = eval_env.step(action_step.action)\n",
        "      video.append_data(eval_py_env.render())\n",
        "\n",
        "embed_mp4(video_filename)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "6_reinforce_tutorial.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
